{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "239be6af",
   "metadata": {},
   "source": [
    "# Lesson 2 - Batching\n",
    "\n",
    "In this lesson, we'll discuss the concept of \"batching\" in LLM inference.\n",
    "\n",
    "- What is batching?\n",
    "- Throughput vs latency"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dd87319-4656-45f1-a6be-0a0b970e01f9",
   "metadata": {},
   "source": [
    "### Import required packages and load the LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ea434a9-22e6-4bee-8ef3-65e8992ce089",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import time\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7066fad6-94df-4153-a6de-a33d7dace6b4",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "model_name = \"./models/gpt2\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6ca61e7-ad31-49ae-9338-6e38b0881dde",
   "metadata": {},
   "source": [
    "### Reuse KV-cache text generation function from Lesson 1\n",
    "- Use the same prompt as the previous lesson to verify everything is working as expected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b79f19-38f6-4f34-804d-921ceb40931b",
   "metadata": {
    "height": 641
   },
   "outputs": [],
   "source": [
    "prompt = \"The quick brown fox jumped over the\"\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "\n",
    "\n",
    "def generate_token_with_past(inputs):\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "\n",
    "    logits = outputs.logits\n",
    "    last_logits = logits[0, -1, :]\n",
    "    next_token_id = last_logits.argmax()\n",
    "    return next_token_id, outputs.past_key_values\n",
    "\n",
    "\n",
    "def generate(inputs, max_tokens):\n",
    "    generated_tokens = []\n",
    "    next_inputs = inputs\n",
    "    for _ in range(max_tokens):\n",
    "        next_token_id, past_key_values = \\\n",
    "        generate_token_with_past(next_inputs)\n",
    "        next_inputs = {\n",
    "            \"input_ids\": next_token_id.reshape((1, 1)),\n",
    "            \"attention_mask\": torch.cat(\n",
    "                [next_inputs[\"attention_mask\"], torch.tensor([[1]])],\n",
    "                dim=1\n",
    "            ),\n",
    "            \"past_key_values\": past_key_values,\n",
    "        }\n",
    "\n",
    "        next_token = tokenizer.decode(next_token_id)\n",
    "        generated_tokens.append(next_token)\n",
    "    return \"\".join(generated_tokens)\n",
    "\n",
    "\n",
    "tokens = generate(inputs, max_tokens=10)\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "802ad913-0c0e-450b-afa0-ff289b85b5f1",
   "metadata": {},
   "source": [
    "### Add padding tokens to the model to prepare batches of prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d895c2e9-fd7f-41d9-af00-a44a19d512b2",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Define PAD Token = EOS Token = 50256\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "model.config.pad_token_id = model.config.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e09245a5-d032-4a45-a2e8-dad82305b449",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# pad on the left so we can append new tokens on the right\n",
    "tokenizer.padding_side = \"left\"\n",
    "tokenizer.truncation_side = \"left\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ea8d58c-3732-4b00-9819-9d3c45e4fd04",
   "metadata": {},
   "source": [
    "- Tokenize list of prompts\n",
    "- Add padding so that all prompts have the same number of tokens as the longest prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17dc969e-eaf1-4594-b328-2fbb8dbd470a",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "# multiple prompts of varying lengths to send\n",
    "# to the model at once\n",
    "prompts = [\n",
    "    \"The quick brown fox jumped over the\",\n",
    "    \"The rain in Spain falls\",\n",
    "    \"What comes up must\",\n",
    "]\n",
    "\n",
    "# note: padding=True ensures the padding token\n",
    "# will be inserted into the tokenized tensors\n",
    "inputs = tokenizer(prompts, padding=True, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54650f21-92cd-4d78-b2e3-ddd3dd811687",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "print(\"input_ids:\", inputs[\"input_ids\"])\n",
    "print(\"shape:\", inputs[\"input_ids\"].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f80e4eaa-819e-4209-9154-58b77464043d",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "print(\"attention_mask:\", inputs[\"attention_mask\"])\n",
    "print(\"shape:\", inputs[\"attention_mask\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5264a70b-d66c-439a-873b-732579eb71c3",
   "metadata": {},
   "source": [
    "- Add position ids to track original order of tokens in each prompt\n",
    "- Padding tokens are set to `1` and then first real token starts with position `0`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a9fd5ca-7bc0-48ad-a6f4-88c29598570c",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# position_ids tell the transformer the ordinal position\n",
    "# of each token in the input sequence\n",
    "# for single input inference, this is just [0 .. n]\n",
    "# for n tokens, but for batch inference,\n",
    "# we need to 0 out the padding tokens at the start of the sequence\n",
    "attention_mask = inputs[\"attention_mask\"]\n",
    "position_ids = attention_mask.long().cumsum(-1) - 1\n",
    "position_ids.masked_fill_(attention_mask == 0, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d68e22e-73a2-4527-9529-11cf98ea139e",
   "metadata": {},
   "source": [
    "- Pass tokens to model to calculate logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "117ed42c-cc0b-4ca8-b656-dd1484fc5574",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# same as before, but include the position_ids\n",
    "with torch.no_grad():\n",
    "    outputs = model(position_ids=position_ids, **inputs)\n",
    "logits = outputs.logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cb19b85-673a-4f0c-ac11-3e82b0f9ddc9",
   "metadata": {},
   "source": [
    "- Retrieve most likely token for each prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70df683-bc0a-403b-9cbf-014cb2362f01",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "last_logits = logits[:, -1, :] \n",
    "next_token_ids = last_logits.argmax(dim=1) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ede5ecf0",
   "metadata": {},
   "source": [
    "- Print the next token ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f61f3a-5f19-47e4-92cc-8d30ef574944",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "print(next_token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c645c96",
   "metadata": {},
   "source": [
    "- Convert the token ids into strings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78a8061b-643f-4125-a5a7-fb6461523dfe",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "next_tokens = tokenizer.batch_decode(next_token_ids)\n",
    "next_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad750c23",
   "metadata": {},
   "source": [
    "### Let's put it all together!\n",
    " - Generate n tokens with past"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3543bb0-2a33-49dc-9d61-5af911a03d24",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "def generate_batch_tokens_with_past(inputs):\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "\n",
    "    logits = outputs.logits\n",
    "    last_logits = logits[:, -1, :]\n",
    "    next_token_ids = last_logits.argmax(dim=1)\n",
    "    return next_token_ids, outputs.past_key_values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea81cfb2",
   "metadata": {},
   "source": [
    "- Generate all tokens for some max tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb50b1e6-415f-4315-8816-8818df516a07",
   "metadata": {
    "height": 590
   },
   "outputs": [],
   "source": [
    "def generate_batch(inputs, max_tokens):\n",
    "    # create a list of tokens for every input in the batch\n",
    "    generated_tokens = [\n",
    "        [] for _ in range(inputs[\"input_ids\"].shape[0])\n",
    "    ]\n",
    "\n",
    "    attention_mask = inputs[\"attention_mask\"]\n",
    "    position_ids = attention_mask.long().cumsum(-1) - 1\n",
    "    position_ids.masked_fill_(attention_mask == 0, 1)\n",
    "\n",
    "    next_inputs = {\n",
    "        \"position_ids\": position_ids,\n",
    "        **inputs\n",
    "    }\n",
    "\n",
    "    for _ in range(max_tokens):\n",
    "        next_token_ids, past_key_values = \\\n",
    "            generate_batch_tokens_with_past(next_inputs)\n",
    "\n",
    "        next_inputs = {\n",
    "            \"input_ids\": next_token_ids.reshape((-1, 1)),\n",
    "            \"position_ids\": next_inputs[\"position_ids\"][:, -1].unsqueeze(-1) + 1,\n",
    "            \"attention_mask\": torch.cat([\n",
    "                next_inputs[\"attention_mask\"],\n",
    "                torch.ones((next_token_ids.shape[0], 1)),  \n",
    "            ], dim=1),\n",
    "            \"past_key_values\": past_key_values,\n",
    "        }\n",
    "\n",
    "        next_tokens = tokenizer.batch_decode(next_token_ids)\n",
    "        for i, token in enumerate(next_tokens):\n",
    "            generated_tokens[i].append(token)\n",
    "    return [\"\".join(tokens) for tokens in generated_tokens]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c8845f1",
   "metadata": {},
   "source": [
    "- Call the generate_batch function and print out the generated tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bd8d5cc-c06c-4efc-97da-ebcd3b2e15b6",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "generated_tokens = generate_batch(inputs, max_tokens=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa48f7c0-6633-400e-aa11-e74000497b65",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "for prompt, generated in zip(prompts, generated_tokens):\n",
    "    print(prompt, f\"\\x1b[31m{generated}\\x1b[0m\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e077e2c0",
   "metadata": {},
   "source": [
    "### Throughput vs Latency\n",
    "\n",
    "- Explore the effect of batching on latency (how long it takes to generate each token). \n",
    "- Observe the fundamental tradeoff that exists between throughput and latency.\n",
    "\n",
    "**Note:** Your results might differ somewhat from those shown in the video, but they will still follow the same pattern as explained by the instructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35e3a944-e2f3-4cda-a5e3-d753b223e73a",
   "metadata": {
    "height": 607
   },
   "outputs": [],
   "source": [
    "# constants\n",
    "max_tokens = 10\n",
    "\n",
    "# observations\n",
    "durations = []\n",
    "throughputs = []\n",
    "latencies = []\n",
    "\n",
    "batch_sizes = [2**p for p in range(8)]\n",
    "for batch_size in batch_sizes:\n",
    "    print(f\"bs= {batch_size}\")\n",
    "\n",
    "    # generate tokens for batch and record duration\n",
    "    t0 = time.time()\n",
    "    batch_prompts = [\n",
    "        prompts[i % len(prompts)] for i in range(batch_size)\n",
    "    ]\n",
    "    inputs = tokenizer(\n",
    "        batch_prompts, padding=True, return_tensors=\"pt\"\n",
    "    )\n",
    "    generated_tokens = generate_batch(inputs, max_tokens=max_tokens)\n",
    "    duration_s = time.time() - t0\n",
    "\n",
    "    ntokens = batch_size * max_tokens\n",
    "    throughput = ntokens / duration_s\n",
    "    avg_latency = duration_s / max_tokens\n",
    "    print(\"duration\", duration_s)\n",
    "    print(\"throughput\", throughput)\n",
    "    print(\"avg latency\", avg_latency)    \n",
    "    print()\n",
    "\n",
    "    durations.append(duration_s)\n",
    "    throughputs.append(throughput)\n",
    "    latencies.append(avg_latency)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ff34a13",
   "metadata": {},
   "source": [
    "### Let's plot the throughput and latency observations against the batch size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "898c44fc-e5ba-48dd-8d80-aae9eddec98b",
   "metadata": {
    "height": 403
   },
   "outputs": [],
   "source": [
    "def render_plot(x, y1, y2, x_label, y1_label, y2_label):\n",
    "    # Create a figure and a set of subplots\n",
    "    fig, ax1 = plt.subplots()\n",
    "\n",
    "    # Plot the first line (throughput)\n",
    "    color = 'tab:red'\n",
    "    ax1.set_xlabel(x_label)\n",
    "    ax1.set_ylabel(y1_label, color=color)\n",
    "    ax1.plot(x, y1, color=color)\n",
    "    ax1.tick_params(axis='y', labelcolor=color)\n",
    "\n",
    "    # Set the x-axis to be log-scaled\n",
    "    ax1.set_xscale('log', base=2)\n",
    "\n",
    "    # Instantiate a second axes that shares the same x-axis\n",
    "    ax2 = ax1.twinx()  \n",
    "    color = 'tab:blue'\n",
    "    ax2.set_ylabel(y2_label, color=color)  # we already handled the x-label with ax1\n",
    "    ax2.plot(x, y2, color=color)\n",
    "    ax2.tick_params(axis='y', labelcolor=color)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6c99edb",
   "metadata": {},
   "source": [
    "**Note**: Your plot may vary slightly from the one shown in the video, yet it will exhibit a similar pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e2f578-bb58-404a-9611-79b7cf79ed09",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "render_plot(\n",
    "    batch_sizes,\n",
    "    throughputs,\n",
    "    latencies,\n",
    "    \"Batch Size\",\n",
    "    \"Throughput\",\n",
    "    \"Latency\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
